package watch

import (
	"context"
	"database/sql"
	"errors"
	"strings"
	"time"

	"github.com/go-kit/kit/log"
	"github.com/go-kit/kit/log/level"
	"github.com/go-redis/redis/v8"
	"github.com/olivere/elastic/v7"
)

// MessageIndex is the name of the Elasticsearch index searched by the
// message-count queries in this package.
const MessageIndex = "message"

// WatchRepository is the storage contract for the "watched" (followed)
// chatrooms and members of a user: listing them, toggling the watched flag,
// checking whether a single item is watched, and counting messages per id.
type WatchRepository interface {
	// Watched chatrooms.

	// SearchWatchedChatroom returns a total count and the matching chatrooms.
	SearchWatchedChatroom(ctx context.Context, opt searchWatchedChatroomRequest) (int64, []*WatchedChatroom, error)
	// ModifyWatchedChatroom adds or removes a chatroom watch record.
	ModifyWatchedChatroom(ctx context.Context, opt modifyWatchedChatroomRequest) error
	// GetChatroomWatchedStatus reports whether userId watches chatroomId.
	GetChatroomWatchedStatus(ctx context.Context, chatroomId, userId string) (bool, error)
	// ChatroomMessageCount7 returns a message count per chatroom id.
	ChatroomMessageCount7(ctx context.Context, ids []string) (map[string]int64, error)

	// Watched members.

	// SearchWatchedMember returns a total count and the matching members.
	SearchWatchedMember(ctx context.Context, opt searchWatchedMemberRequest) (int64, []*WatchedMember, error)
	// ModifyWatchedMember adds or removes a member watch record.
	ModifyWatchedMember(ctx context.Context, opt modifyWatchedMemberRequest) error
	// GetMemberWatchedStatus reports whether userId watches memberId.
	GetMemberWatchedStatus(ctx context.Context, memberId, userId string) (bool, error)
	// MembersMessageCount7 returns a message count per member id.
	MembersMessageCount7(ctx context.Context, ids []string) (map[string]int64, error)
}

// watchRepository is the WatchRepository implementation built by
// NewWatchRepository. It only stores the clients handed to the constructor:
//
//   - mysqlClient runs the chatroom_watched / member_watched SQL statements.
//   - redisClient is stored but not read by any method visible in this part
//     of the file.
//   - elasticClient runs the per-id message-count aggregations.
//   - logger receives the error logs emitted by the methods.
type watchRepository struct {
	mysqlClient   *sql.DB
	redisClient   *redis.Client
	elasticClient *elastic.Client
	logger        log.Logger
}

// WatchedChatroomRow is the nullable scan target for one result row of the
// SearchWatchedChatroom query. The query LEFT JOINs several tables, so any
// joined column may be NULL; the sql.Null* types absorb that.
//
// Field-to-column mapping, in scan order:
//
//	Id            <- chatroom_watched.chatroomId
//	Name          <- chatroom.name
//	MemberCount   <- COUNT(memberId) per chatroom from chatroom_member
//	OwnerId       <- member.id (joined on chatroom.ownerId)
//	OwnerNickname <- member.name
//	DataSource    <- chatroom.dataSource
//	Category      <- category.name
//	AddTimestamp  <- chatroom_watched.addTimestamp
type WatchedChatroomRow struct {
	Id            sql.NullString
	Name          sql.NullString
	MemberCount   sql.NullInt64
	OwnerId       sql.NullString
	OwnerNickname sql.NullString
	DataSource    sql.NullString
	Category      sql.NullString
	AddTimestamp  sql.NullInt64
}

// NewWatchRepository stores the given MySQL, Redis and Elasticsearch clients
// and the logger in a new repository and returns it as a WatchRepository.
// The arguments are kept as-is; nothing is validated or connected here.
func NewWatchRepository(mysqlClient *sql.DB, redisClient *redis.Client, elasticClient *elastic.Client, logger log.Logger) WatchRepository {
	repo := new(watchRepository)
	repo.mysqlClient = mysqlClient
	repo.redisClient = redisClient
	repo.elasticClient = elasticClient
	repo.logger = logger
	return repo
}

// SearchWatchedChatroom lists the chatrooms watched by opt.UserId, optionally
// filtered and paginated, together with a total count.
//
// Filters (each applied only when its option is set):
//   - opt.DataSource / opt.CategoryId: exact match on the chatroom.
//   - opt.Name: substring match (LIKE %name%) on the chatroom name.
//   - opt.StartTimestamp / opt.EndTimestamp: inclusive window on the watch
//     timestamp, used only when both are positive and End > Start.
//   - opt.LowMemberCount / opt.HighMemberCount: inclusive member-count range,
//     or a lower bound only when HighMemberCount is 0.
//
// Pagination is applied only when both opt.Size and opt.Page are positive;
// the offset is Size*(Page-1).
//
// Returns the number of chatroom_watched rows for the user, the matching
// chatrooms, and any query or iteration error. Rows that fail to scan are
// logged and skipped.
//
// NOTE(review): the returned count comes from an unfiltered query, so it can
// exceed the number of rows matching the filters. The LIMIT is also applied
// without an ORDER BY, so page contents are not guaranteed stable. Both are
// left as-is because callers may depend on them — confirm intent.
func (r *watchRepository) SearchWatchedChatroom(ctx context.Context, opt searchWatchedChatroomRequest) (int64, []*WatchedChatroom, error) {
	sqlRow := `
	SELECT 
	t1.chatroomId,
	t2.name,
	t5.c,
	t4.id,
	t4.name,
	t2.dataSource,
	t3.name,
	t1.addTimestamp 
	FROM 
	(SELECT * FROM chatroom_watched WHERE userId=?) t1
	LEFT JOIN 
	chatroom t2
	ON t1.chatroomId = t2.id
	LEFT JOIN
	category t3
	ON t2.categoryId = t3.id
	LEFT JOIN
	member t4
	ON t2.ownerId = t4.id
	LEFT JOIN
	(SELECT COUNT(memberId) c, chatroomId FROM chatroom_member GROUP BY chatroomId) t5
	ON t1.chatroomId = t5.chatroomId`
	sqlCount := `SELECT COUNT(chatroomId) FROM chatroom_watched WHERE userId=?`
	var (
		query []string
		args  []interface{}
		count int64
	)
	// The first placeholder belongs to the t1 sub-select (userId).
	args = append(args, opt.UserId)
	if opt.DataSource != "" {
		query = append(query, " t2.dataSource=? ")
		args = append(args, opt.DataSource)
	}
	if opt.Name != "" {
		query = append(query, " t2.name like ? ")
		args = append(args, "%"+opt.Name+"%")
	}
	if opt.StartTimestamp > 0 && opt.EndTimestamp > 0 && opt.EndTimestamp > opt.StartTimestamp {
		query = append(query, " t1.addTimestamp>=? ", " t1.addTimestamp<=? ")
		args = append(args, opt.StartTimestamp, opt.EndTimestamp)
	}
	if opt.CategoryId != "" {
		query = append(query, " t2.categoryId=? ")
		args = append(args, opt.CategoryId)
	}
	// NOTE(review): with both bounds left at zero the second branch still
	// fires and adds "t5.c>=0". Because t5 is LEFT JOINed, chatrooms with no
	// chatroom_member rows have a NULL count and are filtered out by it.
	// Behavior is preserved here — confirm whether that is intended.
	if opt.LowMemberCount >= 0 && opt.HighMemberCount >= 0 && opt.LowMemberCount < opt.HighMemberCount {
		query = append(query, " t5.c>=? ", " t5.c<=? ")
		args = append(args, opt.LowMemberCount, opt.HighMemberCount)
	} else if opt.LowMemberCount >= 0 && opt.HighMemberCount == 0 {
		query = append(query, "t5.c>=?")
		args = append(args, opt.LowMemberCount)
	}
	// Test the condition list itself rather than inferring it from len(args).
	if len(query) > 0 {
		sqlRow += " WHERE " + strings.Join(query, " AND ")
	}
	if opt.Size > 0 && opt.Page > 0 {
		sqlRow += " LIMIT ?,?"
		beforeCount := opt.Size * (opt.Page - 1)
		args = append(args, beforeCount, opt.Size)
	}

	// Use the caller's context so cancellation and deadlines reach the driver.
	rows, err := r.mysqlClient.QueryContext(ctx, sqlRow, args...)
	if err != nil {
		return 0, nil, err
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			_ = r.logger.Log("search watched chatroom close error:", closeErr)
		}
	}()

	var chatrooms []*WatchedChatroom
	for rows.Next() {
		var row WatchedChatroomRow
		err = rows.Scan(
			&row.Id,
			&row.Name,
			&row.MemberCount,
			&row.OwnerId,
			&row.OwnerNickname,
			&row.DataSource,
			&row.Category,
			&row.AddTimestamp,
		)
		if err != nil {
			// Best effort: a single bad row is logged and skipped.
			_ = r.logger.Log("Scan error", err)
			continue
		}
		chatrooms = append(chatrooms, WatchedChatroomToWatchedChatroom(row))
	}
	// Next() returning false can mean "error", not only "no more rows".
	if err = rows.Err(); err != nil {
		return 0, nil, err
	}

	err = r.mysqlClient.QueryRowContext(ctx, sqlCount, opt.UserId).Scan(&count)
	if err != nil {
		return 0, nil, err
	}
	return count, chatrooms, nil
}

// WatchedMemberRow is a nullable row shape with the same five columns that
// SearchWatchedMember selects (id, nickname, dataSource, address,
// addTimestamp).
//
// NOTE(review): nothing in the visible part of this file references this
// type; SearchWatchedMember scans straight into WatchedMember instead.
type WatchedMemberRow struct {
	Id           sql.NullString
	Nickname     sql.NullString
	DataSource   sql.NullString
	Address      sql.NullString
	AddTimestamp sql.NullInt64
}

// SearchWatchedMember lists watched members (member joined with
// member_watched), newest watch first, together with a total count.
//
// Filters (each applied only when its option is non-empty / non-zero):
//   - opt.UserId: the watching user.
//   - opt.StartTimestamp / opt.EndTimestamp: inclusive window on the watch
//     timestamp, used only when both are non-zero.
//   - opt.DataSource, opt.Address: exact match.
//   - opt.Nickname, opt.Comment: substring match (LIKE %value%).
//
// Pagination always applies; a non-positive opt.Size defaults to 10 and a
// non-positive opt.Page defaults to 1.
//
// Returns the number of member_watched rows for opt.UserId, the matching
// members, and an error. Database failures are logged and reported to the
// caller as a generic error. Rows that fail to scan are logged and skipped.
//
// NOTE(review): the returned count comes from an unfiltered query keyed only
// on opt.UserId, so it can differ from the number of rows matching the
// filters — confirm intent.
func (r *watchRepository) SearchWatchedMember(ctx context.Context, opt searchWatchedMemberRequest) (int64, []*WatchedMember, error) {
	var count int64
	sqlRow := `select distinct id, nickname, dataSource, address, mw.addTimestamp 
	from member as m 
	inner join member_watched as mw 
	on mw.memberId=m.id`
	sqlCount := `SELECT COUNT(memberId) FROM member_watched WHERE userId=?`
	var (
		query []string
		args  []interface{}
	)
	if len(opt.UserId) > 0 {
		query = append(query, " mw.userId = ?")
		args = append(args, opt.UserId)
	}
	if opt.StartTimestamp != 0 && opt.EndTimestamp != 0 {
		// Qualified with mw. to match the SELECT and ORDER BY columns.
		query = append(query, " mw.addTimestamp >= ?", " mw.addTimestamp <= ?")
		args = append(args, opt.StartTimestamp, opt.EndTimestamp)
	}
	if len(opt.DataSource) > 0 {
		query = append(query, " dataSource = ?")
		args = append(args, opt.DataSource)
	}
	if len(opt.Nickname) > 0 {
		query = append(query, " nickname like ?")
		args = append(args, "%"+opt.Nickname+"%")
	}
	if len(opt.Address) > 0 {
		query = append(query, " address = ?")
		args = append(args, opt.Address)
	}
	if len(opt.Comment) > 0 {
		query = append(query, " comment like ?")
		args = append(args, "%"+opt.Comment+"%")
	}
	if opt.Size <= 0 {
		opt.Size = 10
	}
	if opt.Page <= 0 {
		opt.Page = 1
	}
	// Only emit WHERE when there is at least one condition; an empty
	// condition list would otherwise produce "WHERE ORDER BY ...".
	if len(query) > 0 {
		sqlRow += " WHERE" + strings.Join(query, " AND ")
	}
	sqlRow += " ORDER BY mw.addTimestamp desc LIMIT ?,?"
	args = append(args, (opt.Page-1)*opt.Size, opt.Size)

	// Use the caller's context so cancellation and deadlines reach the driver.
	rows, err := r.mysqlClient.QueryContext(ctx, sqlRow, args...)
	if err != nil {
		_ = r.logger.Log("get watched member by user error:", err)
		return 0, nil, errors.New("网络异常，请重试")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			_ = r.logger.Log("get watched member by user close error:", closeErr)
		}
	}()

	watchMemberRes := make([]*WatchedMember, 0)
	for rows.Next() {
		var member WatchedMember
		err = rows.Scan(
			&member.Id,
			&member.Nickname,
			&member.DataSource,
			&member.Address,
			&member.AddTimestamp,
		)
		if err != nil {
			// Best effort: a single bad row is logged and skipped.
			_ = r.logger.Log("get watched member by user scan error:", err)
			continue
		}
		watchMemberRes = append(watchMemberRes, &member)
	}
	// Next() returning false can mean "error", not only "no more rows".
	if err = rows.Err(); err != nil {
		_ = r.logger.Log("get watched member by user rows error:", err)
		return 0, nil, errors.New("网络异常，请重试")
	}

	// The count error used to be assigned and then dropped, which reported
	// a total of 0 with a nil error.
	if err = r.mysqlClient.QueryRowContext(ctx, sqlCount, opt.UserId).Scan(&count); err != nil {
		_ = r.logger.Log("count watched member by user error:", err)
		return 0, nil, errors.New("网络异常，请重试")
	}
	return count, watchMemberRes, nil
}

// ModifyWatchedChatroom toggles the watch record for (opt.Id, opt.UserId) in
// chatroom_watched. opt.Watched is the current state: when it is true the
// record is deleted, otherwise a record is inserted (INSERT IGNORE) with the
// current time in milliseconds. Database failures are logged and replaced by
// a fixed user-facing error.
func (r *watchRepository) ModifyWatchedChatroom(ctx context.Context, opt modifyWatchedChatroomRequest) error {
	if !opt.Watched {
		const insertStmt = `INSERT IGNORE INTO chatroom_watched(chatroomId, userId, addTimestamp) VALUES(?, ?, ?);`
		// Milliseconds since the Unix epoch.
		addedAt := time.Now().UnixNano() / 1e6
		if _, execErr := r.mysqlClient.ExecContext(ctx, insertStmt, opt.Id, opt.UserId, addedAt); execErr != nil {
			_ = level.Error(r.logger).Log("set watched error:", execErr)
			return errors.New("设置关注状态失败")
		}
		return nil
	}

	const deleteStmt = `DELETE FROM chatroom_watched WHERE chatroomId=? AND userId=?`
	if _, execErr := r.mysqlClient.ExecContext(ctx, deleteStmt, opt.Id, opt.UserId); execErr != nil {
		_ = level.Error(r.logger).Log("delete watched error:", execErr)
		return errors.New("清除关注状态失败")
	}
	return nil
}

// ModifyWatchedMember toggles the watch record for (opt.Id, opt.UserId) in
// member_watched. opt.Watched is the current state: when it is true the
// record is deleted, otherwise a record is inserted (INSERT IGNORE) with the
// current time in milliseconds. Database failures are logged and replaced by
// a fixed user-facing error.
func (r *watchRepository) ModifyWatchedMember(ctx context.Context, opt modifyWatchedMemberRequest) error {
	if !opt.Watched {
		const insertStmt = `INSERT IGNORE INTO member_watched(memberId, userId, addTimestamp) VALUES(?, ?, ?);`
		// Milliseconds since the Unix epoch.
		addedAt := time.Now().UnixNano() / 1e6
		if _, execErr := r.mysqlClient.ExecContext(ctx, insertStmt, opt.Id, opt.UserId, addedAt); execErr != nil {
			_ = level.Error(r.logger).Log("set watched error:", execErr)
			return errors.New("设置关注状态失败")
		}
		return nil
	}

	const deleteStmt = `DELETE FROM member_watched WHERE memberId=? AND userId=?`
	if _, execErr := r.mysqlClient.ExecContext(ctx, deleteStmt, opt.Id, opt.UserId); execErr != nil {
		_ = level.Error(r.logger).Log("delete watched error:", execErr)
		return errors.New("清除关注状态失败")
	}
	return nil
}

// WatchedChatroomToWatchedChatroom converts a scanned, nullable database row
// into a newly allocated WatchedChatroom. Validity flags are not inspected:
// a NULL column ends up as the zero value ("" or 0) of the target field.
func WatchedChatroomToWatchedChatroom(row WatchedChatroomRow) *WatchedChatroom {
	chatroom := new(WatchedChatroom)

	// Chatroom identity and classification.
	chatroom.Id = row.Id.String
	chatroom.Name = row.Name.String
	chatroom.DataSource = row.DataSource.String
	chatroom.Category = row.Category.String

	// Owner.
	chatroom.OwnerId = row.OwnerId.String
	chatroom.OwnerNickname = row.OwnerNickname.String

	// Numeric columns.
	chatroom.MemberCount = row.MemberCount.Int64
	chatroom.AddTimestamp = row.AddTimestamp.Int64

	return chatroom
}

// MembersMessageCount7 returns, for each of the given member ids, the number
// of documents in the message index whose "member.id" equals that id. Ids
// with no matching documents are absent from the result map.
//
// An empty ids slice returns an empty map without contacting Elasticsearch:
// the bool query would otherwise have no clauses and the terms aggregation
// would be sent with size 0.
//
// Request failures are logged and reported as a generic error; a response
// without the expected aggregation yields a "not found" error.
//
// NOTE(review): the "7" in the name suggests a 7-day window, but no time
// range is applied by this query — confirm whether one is expected.
func (r *watchRepository) MembersMessageCount7(ctx context.Context, ids []string) (map[string]int64, error) {
	result := make(map[string]int64, len(ids))
	if len(ids) == 0 {
		return result, nil
	}

	query := elastic.NewBoolQuery()
	for _, id := range ids {
		query.Should(elastic.NewTermQuery("member.id", id))
	}
	// One bucket per requested id is enough.
	aggQuery := elastic.NewTermsAggregation().Field("member.id").Size(len(ids))

	// Use the caller's context so cancellation and deadlines are honored.
	resp, err := r.elasticClient.
		Search(MessageIndex).
		Query(query).
		Aggregation("groupByMemberId", aggQuery).
		Do(ctx)
	if err != nil {
		_ = level.Error(r.logger).Log("elasticsearch connection error:", err)
		return nil, errors.New("网络异常，请重试")
	}

	termsResult, found := resp.Aggregations.Terms("groupByMemberId")
	if !found {
		err = errors.New("not found")
		_ = level.Error(r.logger).Log("count members message:", err)
		return nil, err
	}
	for _, bucket := range termsResult.Buckets {
		// Bucket keys are decoded as interface{}; a non-string key (for
		// example a numeric field mapping) must not panic the caller.
		key, ok := bucket.Key.(string)
		if !ok {
			_ = level.Error(r.logger).Log("count members message: unexpected bucket key type:", bucket.Key)
			continue
		}
		result[key] = bucket.DocCount
	}
	return result, nil
}

// ChatroomMessageCount7 returns, for each of the given chatroom ids, the
// number of documents in the message index whose "chatroom.id" equals that
// id. Ids with no matching documents are absent from the result map.
//
// An empty ids slice returns an empty map without contacting Elasticsearch:
// the bool query would otherwise have no clauses and the terms aggregation
// would be sent with size 0.
//
// Request failures are logged and reported as a generic error; a response
// without the expected aggregation yields a "not found" error.
//
// NOTE(review): the "7" in the name suggests a 7-day window, but no time
// range is applied by this query — confirm whether one is expected.
func (r *watchRepository) ChatroomMessageCount7(ctx context.Context, ids []string) (map[string]int64, error) {
	result := make(map[string]int64, len(ids))
	if len(ids) == 0 {
		return result, nil
	}

	query := elastic.NewBoolQuery()
	for _, id := range ids {
		query.Should(elastic.NewTermQuery("chatroom.id", id))
	}
	// One bucket per requested id is enough.
	aggQuery := elastic.NewTermsAggregation().Field("chatroom.id").Size(len(ids))

	// Use the caller's context so cancellation and deadlines are honored.
	resp, err := r.elasticClient.
		Search(MessageIndex).
		Query(query).
		Aggregation("groupByChatroomId", aggQuery).
		Do(ctx)
	if err != nil {
		_ = level.Error(r.logger).Log("elasticsearch connection error:", err)
		return nil, errors.New("网络异常，请重试")
	}

	termsResult, found := resp.Aggregations.Terms("groupByChatroomId")
	if !found {
		err = errors.New("not found")
		_ = level.Error(r.logger).Log("count chatroom message:", err)
		return nil, err
	}
	for _, bucket := range termsResult.Buckets {
		// Bucket keys are decoded as interface{}; a non-string key (for
		// example a numeric field mapping) must not panic the caller.
		key, ok := bucket.Key.(string)
		if !ok {
			_ = level.Error(r.logger).Log("count chatroom message: unexpected bucket key type:", bucket.Key)
			continue
		}
		result[key] = bucket.DocCount
	}
	return result, nil
}

// MemberWatchedStatus is the scan destination for one watch record: the id
// of the watched item, the id of the watching user, and the timestamp stored
// when the record was added. The watched-status lookups in this file scan
// into it only to detect whether a row exists; the values themselves are not
// returned to callers.
type MemberWatchedStatus struct {
	MemberId     string
	UserId       string
	AddTimestamp int64
}

// GetMemberWatchedStatus reports whether userId has a member_watched record
// for memberId.
//
// It returns (true, nil) when a row exists, (false, nil) when there is none,
// and (false, err) after logging for any other query or scan failure.
func (r *watchRepository) GetMemberWatchedStatus(ctx context.Context, memberId string, userId string) (bool, error) {
	sqlRow := `select memberId, userId, addTimestamp from member_watched where memberId = ? and userId = ?;`
	var memberWatchedRow MemberWatchedStatus
	// Use the caller's context so cancellation and deadlines reach the driver.
	err := r.mysqlClient.QueryRowContext(ctx, sqlRow, memberId, userId).
		Scan(&memberWatchedRow.MemberId, &memberWatchedRow.UserId, &memberWatchedRow.AddTimestamp)
	switch {
	case err == nil:
		return true, nil
	case errors.Is(err, sql.ErrNoRows):
		// No watch record for this member/user pair.
		return false, nil
	default:
		_ = level.Error(r.logger).Log("get member watched scan error:", err)
		return false, err
	}
}

// GetChatroomWatchedStatus reports whether userId has a chatroom_watched
// record for chatroomId.
//
// It returns (true, nil) when a row exists, (false, nil) when there is none,
// and (false, err) after logging for any other query or scan failure.
func (r *watchRepository) GetChatroomWatchedStatus(ctx context.Context, chatroomId string, userId string) (bool, error) {
	sqlRow := `select chatroomId, userId, addTimestamp from chatroom_watched where chatroomId = ? and userId = ?;`
	// MemberWatchedStatus is reused as the scan target; its MemberId field
	// receives the chatroomId column here.
	var watchedRow MemberWatchedStatus
	// Use the caller's context so cancellation and deadlines reach the driver.
	err := r.mysqlClient.QueryRowContext(ctx, sqlRow, chatroomId, userId).
		Scan(&watchedRow.MemberId, &watchedRow.UserId, &watchedRow.AddTimestamp)
	switch {
	case err == nil:
		return true, nil
	case errors.Is(err, sql.ErrNoRows):
		// No watch record for this chatroom/user pair.
		return false, nil
	default:
		_ = level.Error(r.logger).Log("get chatroom watched scan error:", err)
		return false, err
	}
}
